Converting the State Dict

The training script (train.py) doesn't support any fancy saving/checkpointing methods, but it can optionally save the model at the end of training to a safetensors file. In this notebook we'll show how to load these saved weights for downstream evaluation and usage. This should hopefully become unnecessary as frameworks integrate the changes needed to make FSDP+QLoRA work natively.

As an example, let’s look at a model trained with the following command (using default settings for LoRA rank etc):

python train.py --save_model True --train_type qlora --output_dir qlora_output

We'll load the saved state_dict, and then copy the relevant weights into a PEFT model so that we can save it with PEFT's save_pretrained method.

Let’s start by loading the state dict. If you uncomment the print statement, you’ll see that for every linear layer that had a LoRA adapter, we have something like this:

base_model.model.model.layers.0.mlp.down_proj.base_layer.weight torch.bfloat16 torch.Size([11272192, 1])
base_model.model.model.layers.0.mlp.down_proj.lora_A.default.weight torch.bfloat16 torch.Size([8, 11008])
base_model.model.model.layers.0.mlp.down_proj.lora_B.default.weight torch.bfloat16 torch.Size([4096, 8])

The base weights are the flattened, 4-bit quantized values, which we won't need (we'll load the original base model later); the lora_A and lora_B adapters are the ones we're interested in.
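If you're wondering where that flattened shape comes from: assuming the standard Llama-2-7B dimensions (hidden size 4096, intermediate size 11008), down_proj has 4096 × 11008 parameters stored at 4 bits each, and the packed bytes appear to be saved viewed as bfloat16, so the element count works out as in this quick check:

# Sanity-checking the flattened 4-bit shape (assumes Llama-2-7B dimensions)
n_params = 4096 * 11008   # elements in the full down_proj weight
n_bytes = n_params // 2   # 4-bit quantization packs two values per byte
n_bf16 = n_bytes // 2     # viewed as bfloat16, i.e. 2 bytes per element
print(n_bf16)             # 11272192 -> matches torch.Size([11272192, 1])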

from safetensors import safe_open

tensors = {}
with safe_open("qlora_output/model_state_dict.safetensors", framework="pt", device=0) as f:
    for k in f.keys():
        tensors[k] = f.get_tensor(k) # loads the full tensor given a key
        # print(k, tensors[k].dtype, tensors[k].shape) # Uncomment to view

To save memory, we can drop everything except the LoRA weights (setting the other entries to None releases the underlying tensors):

for k in tensors:
    if 'lora' not in k: tensors[k] = None # drop references to the quantized base weights
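Since the tensors were loaded onto the GPU, you can optionally nudge PyTorch into actually releasing the memory those base weights were using (a minor tidy-up, not strictly required):

import gc, torch
gc.collect() # clear the now-unreferenced 4-bit base weights
torch.cuda.empty_cache() # free PyTorch's unused cached GPU memory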

Next, we load the base model and add a randomly initialized adapter:

import torch
from transformers import LlamaForCausalLM, BitsAndBytesConfig
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType

# Make sure the compute dtype, target modules, rank, alpha etc. match those used in training!
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype=torch.bfloat16
)
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    use_cache=False,
    quantization_config=bnb_config
)

# Freeze
for param in model.parameters():
    param.requires_grad = False

# Add LoRA (make sure your rank (r) and alpha (lora_alpha) values match those used in training!)
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, inference_mode=False, r=64, lora_alpha=16, lora_dropout=0.1,
    target_modules=["k_proj", "q_proj", "v_proj", "up_proj", "down_proj", "gate_proj"]
)
model = get_peft_model(model, peft_config)

# Check out the first few keys in the state dict:
list(model.state_dict().keys())[:10]
['base_model.model.model.embed_tokens.weight',
 'base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight',
 'base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight.absmax',
 'base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight.quant_map',
 'base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight.quant_state.bitsandbytes__nf4',
 'base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight',
 'base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight',
 'base_model.model.model.layers.0.self_attn.k_proj.base_layer.weight',
 'base_model.model.model.layers.0.self_attn.k_proj.base_layer.weight.absmax',
 'base_model.model.model.layers.0.self_attn.k_proj.base_layer.weight.quant_map']
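Before copying anything over, it's worth a quick optional check that the LoRA key names in this freshly created PEFT model line up exactly with the ones we kept from the saved state dict; a mismatch usually means the rank or target modules differ from those used in training:

# Optional sanity check: the adapter key names should match exactly
saved_lora_keys = {k for k in tensors if 'lora' in k}
model_lora_keys = {k for k in model.state_dict() if 'lora' in k}
assert saved_lora_keys == model_lora_keys, "LoRA keys differ - check r and target_modules"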

Now, if all goes well, we can replace the randomly initialized LoRA layers with our trained ones:

new_sd = model.state_dict()
for k in new_sd:
    if 'lora' in k:
        new_sd[k] = tensors[k]

model.load_state_dict(new_sd)
<All keys matched successfully>
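If you want extra reassurance that the trained values (rather than the random initialization) made it in, you can compare one adapter tensor directly, using a key we know exists from the listing above:

# Spot-check a single trained adapter weight against the loaded state dict
k = 'base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight'
# cast both to float32 in case the model kept its LoRA params in a different dtype
assert torch.equal(model.state_dict()[k].cpu().float(), tensors[k].cpu().float())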

And now, since we have a regular PEFT model, we can save using the built-in methods:

model.save_pretrained("lora_adapters")
!ls lora_adapters
README.md  adapter_config.json  adapter_model.safetensors
# model.push_to_hub('your_repo_id') # If you want to share your model...
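For downstream evaluation you can then load these adapters back on top of the base model in the usual PEFT way. A minimal sketch, reusing the bnb_config from above:

from peft import PeftModel

base = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=bnb_config
)
inference_model = PeftModel.from_pretrained(base, "lora_adapters") # loads adapter_config.json + adapter weights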